{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "02151.ipynb",
      "provenance": [],
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/yananma/5_programs_per_day/blob/master/02151.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dUwdf37pw3Tl",
        "colab_type": "text"
      },
      "source": [
        "## 6.7 门控循环单元 ( GRU )"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "fN59qNU0xBwS",
        "colab_type": "text"
      },
      "source": [
        "### 6.7.2 读取数据集"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gyF9fr55tWjY",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "import numpy as np \n",
        "import torch \n",
        "from torch import nn, optim \n",
        "import torch.nn.functional as F \n",
        "import d2l \n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "A1pHs0_6wp0M",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!mkdir ../../data "
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zgxokoIzxOno",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "02b77962-402a-4e64-c991-68e22e042fb4"
      },
      "source": [
        "!git clone https://github.com/ShusenTang/Dive-into-DL-PyTorch.git"
      ],
      "execution_count": 5,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'Dive-into-DL-PyTorch'...\n",
            "remote: Enumerating objects: 1692, done.\u001b[K\n",
            "remote: Total 1692 (delta 0), reused 0 (delta 0), pack-reused 1692\u001b[K\n",
            "Receiving objects: 100% (1692/1692), 25.29 MiB | 14.15 MiB/s, done.\n",
            "Resolving deltas: 100% (975/975), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "S-ZdzFBJxTuZ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "!cp Dive-into-DL-PyTorch/data/jaychou_lyrics.txt.zip ../../data"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "jpVyKyOjxbfj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "(corpus_indices, char_to_idx, idx_to_char, vocab_size) = d2l.load_data_jay_lyrics()"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hALoj9TexnrA",
        "colab_type": "text"
      },
      "source": [
        "### 6.7.3 从零开始实现"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2aZPx9bHxrx_",
        "colab_type": "text"
      },
      "source": [
        "#### 1 初始化模型参数"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "HwqGmQnXxw7n",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "9ab97bc2-6349-46ba-f90f-509ba3d66521"
      },
      "source": [
        "num_inputs, num_hiddens, num_outputs = vocab_size, 256, vocab_size \n",
        "print('will use', device)\n",
        "\n",
        "def get_params():\n",
        "    def _one(shape):\n",
        "        ts = torch.tensor(np.random.normal(0, 0.01, size=shape), device=device, dtype=torch.float32)\n",
        "        return torch.nn.Parameter(ts, requires_grad=True)\n",
        "    def _three():\n",
        "        return (_one((num_inputs, num_hiddens)), \n",
        "            _one((num_hiddens, num_hiddens)), \n",
        "            torch.nn.Parameter(torch.zeros(num_hiddens, device=device, dtype=torch.float32), requires_grad=True))\n",
        "\n",
        "    W_xz, W_hz, b_z = _three()\n",
        "    W_xr, W_hr, b_r = _three()\n",
        "    W_xh, W_hh, b_h = _three()\n",
        "\n",
        "    W_hq = _one((num_hiddens, num_outputs))\n",
        "    b_q = torch.nn.Parameter(torch.zeros(num_outputs, device=device, dtype=torch.float32), requires_grad=True)\n",
        "    return nn.ParameterList([W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q])"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "will use cuda\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9mgsUbcizZU6",
        "colab_type": "text"
      },
      "source": [
        "#### 2 定义模型"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "zyGgsUCmzdnR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def init_gru_state(batch_size, num_hiddens, device):\n",
        "    return (torch.zeros((batch_size, num_hiddens), device=device), )"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "z2sRbkt3zqXq",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def gru(inputs, state, params):\n",
        "    W_xz, W_hz, b_z, W_xr, W_hr, b_r, W_xh, W_hh, b_h, W_hq, b_q = params \n",
        "    H, = state \n",
        "    outputs = []\n",
        "    for X in inputs:\n",
        "        Z = torch.sigmoid(torch.matmul(X, W_xz) + torch.matmul(H, W_hz) + b_z)\n",
        "        R = torch.sigmoid(torch.matmul(X, W_xr) + torch.matmul(H, W_hr) + b_r)\n",
        "        H_tilda = torch.tanh(torch.matmul(X, W_xh) + R * torch.matmul(H, W_hh) + b_h)\n",
        "        H = Z * H + (1 - Z) * H_tilda \n",
        "        Y = torch.matmul(H, W_hq) + b_q\n",
        "        outputs.append(Y)\n",
        "    return outputs, (H, )"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "R_932IYU0uaK",
        "colab_type": "text"
      },
      "source": [
        "#### 3 训练模型并创作歌词"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4MPCU32R00Xj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "num_epochs, num_steps, batch_size, lr, clipping_theta = 160, 35, 32, 1e2, 1e-2 \n",
        "pred_period, pred_len, prefixes = 40, 50, ['分开', '不分开']"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ktBH-aAM1HKs",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "58febb95-7618-4af5-fc5c-03da04c58d44"
      },
      "source": [
        "d2l.train_and_predict_rnn(gru, get_params, init_gru_state, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, \n",
        "            char_to_idx, False, num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, \n",
        "            prefixes)"
      ],
      "execution_count": 13,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 40, perplexity 151.758794, time 0.42 sec\n",
            " - 分开 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我不不 我\n",
            " - 不分开 我想你我 你不我 你不了我 你不了我 你不了我 你不了我 你不了我 你不了我 你不了我 你不了我 \n",
            "epoch 80, perplexity 32.830066, time 0.41 sec\n",
            " - 分开 我想要这样 我不要再想你 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 \n",
            " - 不分开 爱爱我 爱你 我想要这样 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 我不要再想 \n",
            "epoch 120, perplexity 5.056568, time 0.43 sec\n",
            " - 分开我 一场好酒 你来一碗热粥 配色几斤 快使用双截棍 哼哼哈兮 快使用双截棍 哼哼哈兮 快使用双截棍 \n",
            " - 不分开  没有你在我有多难熬多难熬  没有你烦 我有多烦恼  没有你在我有多难熬多难熬  没有你烦 我有多\n",
            "epoch 160, perplexity 1.478603, time 0.42 sec\n",
            " - 分开 让我想大你 这对球 告诉真中 别人止中落不会一力危危 一个风 是果我外的菸以 一起好演日离 我来到\n",
            " - 不分开 你已经离开我 不知不觉 我跟了这节奏 后知后觉 又过了一个秋 后知后觉 我该好好生活 我该好好生活\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MdVpVY6_1lMF",
        "colab_type": "text"
      },
      "source": [
        "### 6.7.4 简洁实现"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ITE2ATJn1wmF",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 218
        },
        "outputId": "02010f71-6ccd-448d-9025-842ecf0cfc38"
      },
      "source": [
        "lr = 1e-2 \n",
        "gru_layer = nn.GRU(input_size=vocab_size, hidden_size=num_hiddens)\n",
        "model = d2l.RNNModel(gru_layer, vocab_size).to(device)\n",
        "d2l.train_and_predict_rnn_pytorch(model, num_hiddens, vocab_size, device, corpus_indices, idx_to_char, char_to_idx, \n",
        "                num_epochs, num_steps, lr, clipping_theta, batch_size, pred_period, pred_len, prefixes)"
      ],
      "execution_count": 14,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "epoch 40, perplexity 1.019791, time 0.06 sec\n",
            " - 分开始的爱写在西元前 深埋在美索不达米亚平原 几十个世纪后出土发现 泥板上的字迹依然清晰可见 我给你的爱\n",
            " - 不分开始打呼 管家是一只会说法语举止优雅的猪 吸血前会念约翰福音做为弥补 拥有一双蓝色眼睛的凯萨琳公主 专\n",
            "epoch 80, perplexity 1.031664, time 0.06 sec\n",
            " - 分开始想像 不想太多 我想一定是我听错弄错搞错 拜托 我想是你的脑袋有问题 随便说说 其实我早已经猜透看\n",
            " - 不分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻 我的伤口被你拆封 誓言太沉重泪被纵容 \n",
            "epoch 120, perplexity 1.010774, time 0.06 sec\n",
            " - 分开始想像 在回忆 的路上 时间变好慢 老街坊 小弄堂 是属于那年代白墙黑瓦的淡淡的忧伤 消失的 旧时光\n",
            " - 不分开始乡相信命运 感谢地心引力 让我碰到你 漂亮的让我面红的可爱女人 温柔的让我心疼的可爱女人 透明的让\n",
            "epoch 160, perplexity 1.013764, time 0.06 sec\n",
            " - 分开始它在空中停留 所有人看着我 抛物线进球 单手过人运球 篮下妙传出手 漂亮的假动作 帅呆了我 全场盯\n",
            " - 不分开始移动 回到当初爱你的时空 停格内容不忠 所有回忆对着我进攻 我的伤口被你拆封 誓言太沉重泪被纵容 \n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}